{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W7rEsKyWcxmu"
      },
      "source": [
        "##### Copyright 2019 The TF-Agents Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nQnmcm0oI1Q-"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G6aOV15Wc4HP"
      },
      "source": [
        "### CheckpointerとPolicySaver\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/agents/tutorials/10_checkpointer_policysaver_tutorial\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"> TensorFlow.orgで表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/agents/tutorials/10_checkpointer_policysaver_tutorial.ipynb\">     <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colabで実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/agents/tutorials/10_checkpointer_policysaver_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示{</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/agents/tutorials/10_checkpointer_policysaver_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード/a0}</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M3HE5S3wsMEh"
      },
      "source": [
        "## はじめに\n",
        "\n",
        "`tf_agents.utils.common.Checkpointer`は、ローカルストレージとの間でトレーニングの状態、ポリシーの状態、およびreplay_bufferの状態を保存/読み込むユーティリティです。\n",
        "\n",
        "`tf_agents.policies.policy_saver.PolicySaver`は、ポリシーのみを保存/読み込むツールであり、`Checkpointer`よりも軽量です。`PolicySaver`を使用すると、ポリシーを作成したコードに関する知識がなくてもモデルをデプロイできます。\n",
        "\n",
        "このチュートリアルでは、DQNを使用してモデルをトレーニングし、次に`Checkpointer`と`PolicySaver`を使用して、状態とモデルをインタラクティブな方法で保存および読み込む方法を紹介します。`PolicySaver`では、TF2.0の新しいsaved_modelツールとフォーマットを使用することに注意してください。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vbTrDrX4dkP_"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Opk_cVDYdgct"
      },
      "source": [
        "以下の依存関係をインストールしていない場合は、実行します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jv668dKvZmka"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!sudo apt-get install -y xvfb ffmpeg\n",
        "!pip install 'gym==0.10.11'\n",
        "!pip install 'imageio==2.4.0'\n",
        "!pip install 'pyglet==1.3.2'\n",
        "!pip install 'xvfbwrapper==0.2.9'\n",
        "!pip install tf-agents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bQMULMo1dCEn"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import base64\n",
        "import imageio\n",
        "import io\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import shutil\n",
        "import tempfile\n",
        "import tensorflow as tf\n",
        "import zipfile\n",
        "import IPython\n",
        "\n",
        "try:\n",
        "  from google.colab import files\n",
        "except ImportError:\n",
        "  files = None\n",
        "from tf_agents.agents.dqn import dqn_agent\n",
        "from tf_agents.drivers import dynamic_step_driver\n",
        "from tf_agents.environments import suite_gym\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.eval import metric_utils\n",
        "from tf_agents.metrics import tf_metrics\n",
        "from tf_agents.networks import q_network\n",
        "from tf_agents.policies import policy_saver\n",
        "from tf_agents.policies import py_tf_eager_policy\n",
        "from tf_agents.policies import random_tf_policy\n",
        "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n",
        "from tf_agents.trajectories import trajectory\n",
        "from tf_agents.utils import common\n",
        "\n",
        "tf.compat.v1.enable_v2_behavior()\n",
        "\n",
        "tempdir = os.getenv(\"TEST_TMPDIR\", tempfile.gettempdir())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AwIqiLdDCX9Q"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "# Set up a virtual display for rendering OpenAI gym environments.\n",
        "import xvfbwrapper\n",
        "xvfbwrapper.Xvfb(1400, 900, 24).start()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AOv_kofIvWnW"
      },
      "source": [
        "## DQNエージェント\n",
        "\n",
        "前のColabと同じように、DQNエージェントを設定します。 このColabでは、詳細は主な部分ではないので、デフォルトでは非表示になっていますが、「コードを表示」をクリックすると詳細を表示できます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cStmaxredFSW"
      },
      "source": [
        "### ハイパーパラメーター"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "yxFs6QU0dGI_"
      },
      "outputs": [],
      "source": [
        "env_name = \"CartPole-v1\"\n",
        "\n",
        "collect_steps_per_iteration = 100\n",
        "replay_buffer_capacity = 100000\n",
        "\n",
        "fc_layer_params = (100,)\n",
        "\n",
        "batch_size = 64\n",
        "learning_rate = 1e-3\n",
        "log_interval = 5\n",
        "\n",
        "num_eval_episodes = 10\n",
        "eval_interval = 1000"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w4GR7RDndIOR"
      },
      "source": [
        "### 環境"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fZwK4d-bdI7Z"
      },
      "outputs": [],
      "source": [
        "train_py_env = suite_gym.load(env_name)\n",
        "eval_py_env = suite_gym.load(env_name)\n",
        "\n",
        "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n",
        "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0AvYRwfkeMvo"
      },
      "source": [
        "### エージェント"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "cUrFl83ieOvV"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "q_net = q_network.QNetwork(\n",
        "    train_env.observation_spec(),\n",
        "    train_env.action_spec(),\n",
        "    fc_layer_params=fc_layer_params)\n",
        "\n",
        "optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)\n",
        "\n",
        "global_step = tf.compat.v1.train.get_or_create_global_step()\n",
        "\n",
        "agent = dqn_agent.DqnAgent(\n",
        "    train_env.time_step_spec(),\n",
        "    train_env.action_spec(),\n",
        "    q_network=q_net,\n",
        "    optimizer=optimizer,\n",
        "    td_errors_loss_fn=common.element_wise_squared_loss,\n",
        "    train_step_counter=global_step)\n",
        "agent.initialize()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p8ganoJhdsbn"
      },
      "source": [
        "### データ収集"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "XiT1p78HdtSe"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n",
        "    data_spec=agent.collect_data_spec,\n",
        "    batch_size=train_env.batch_size,\n",
        "    max_length=replay_buffer_capacity)\n",
        "\n",
        "collect_driver = dynamic_step_driver.DynamicStepDriver(\n",
        "    train_env,\n",
        "    agent.collect_policy,\n",
        "    observers=[replay_buffer.add_batch],\n",
        "    num_steps=collect_steps_per_iteration)\n",
        "\n",
        "# Initial data collection\n",
        "collect_driver.run()\n",
        "\n",
        "# Dataset generates trajectories with shape [BxTx...] where\n",
        "# T = n_step_update + 1.\n",
        "dataset = replay_buffer.as_dataset(\n",
        "    num_parallel_calls=3, sample_batch_size=batch_size,\n",
        "    num_steps=2).prefetch(3)\n",
        "\n",
        "iterator = iter(dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8V8bojrKdupW"
      },
      "source": [
        "### エージェントのトレーニング"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "-rDC3leXdvm_"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n",
        "agent.train = common.function(agent.train)\n",
        "\n",
        "def train_one_iteration():\n",
        "\n",
        "  # Collect a few steps using collect_policy and save to the replay buffer.\n",
        "  for _ in range(collect_steps_per_iteration):\n",
        "    collect_driver.run()\n",
        "\n",
        "  # Sample a batch of data from the buffer and update the agent's network.\n",
        "  experience, unused_info = next(iterator)\n",
        "  train_loss = agent.train(experience)\n",
        "\n",
        "  iteration = agent.train_step_counter.numpy()\n",
        "  print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vgqVaPnUeDAn"
      },
      "source": [
        "### ビデオ生成"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "ZY6w-fcieFDW"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "def embed_gif(gif_buffer):\n",
        "  \"\"\"Embeds a gif file in the notebook.\"\"\"\n",
        "  tag = '<img src=\"data:image/gif;base64,{0}\"/>'.format(base64.b64encode(gif_buffer).decode())\n",
        "  return IPython.display.HTML(tag)\n",
        "\n",
        "def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):\n",
        "  num_episodes = 3\n",
        "  frames = []\n",
        "  for _ in range(num_episodes):\n",
        "    time_step = eval_tf_env.reset()\n",
        "    frames.append(eval_py_env.render())\n",
        "    while not time_step.is_last():\n",
        "      action_step = policy.action(time_step)\n",
        "      time_step = eval_tf_env.step(action_step.action)\n",
        "      frames.append(eval_py_env.render())\n",
        "  gif_file = io.BytesIO()\n",
        "  imageio.mimsave(gif_file, frames, format='gif', fps=60)\n",
        "  IPython.display.display(embed_gif(gif_file.getvalue()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y-oA8VYJdFdj"
      },
      "source": [
        "### ビデオ生成\n",
        "\n",
        "ビデオを生成して、ポリシーのパフォーマンスを確認します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FpmPLXWbdG70"
      },
      "outputs": [],
      "source": [
        "print ('global_step:')\n",
        "print (global_step)\n",
        "run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7RPLExsxwnOm"
      },
      "source": [
        "## チェックポインタとPolicySaverのセットアップ\n",
        "\n",
        "CheckpointerとPolicySaverを使用する準備ができました。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g-iyQJacfQqO"
      },
      "source": [
        "### Checkpointer\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2DzCJZ-6YYbX"
      },
      "outputs": [],
      "source": [
        "checkpoint_dir = os.path.join(tempdir, 'checkpoint')\n",
        "train_checkpointer = common.Checkpointer(\n",
        "    ckpt_dir=checkpoint_dir,\n",
        "    max_to_keep=1,\n",
        "    agent=agent,\n",
        "    policy=agent.policy,\n",
        "    replay_buffer=replay_buffer,\n",
        "    global_step=global_step\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MKpWNZM4WE8d"
      },
      "source": [
        "### Policy Saver"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8mDZ_YMUWEY9"
      },
      "outputs": [],
      "source": [
        "policy_dir = os.path.join(tempdir, 'policy')\n",
        "tf_policy_saver = policy_saver.PolicySaver(agent.policy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1OnANb1Idx8-"
      },
      "source": [
        "### 1回のイテレーションのトレーニング"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ql_D1iq8dl0X"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "print('Training one iteration....')\n",
        "train_one_iteration()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eSChNSQPlySb"
      },
      "source": [
        "### チェックポイントに保存"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "usDm_Wpsl0bu"
      },
      "outputs": [],
      "source": [
        "train_checkpointer.save(global_step)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gTQUrKgihuic"
      },
      "source": [
        "### チェックポイントに復元\n",
        "\n",
        "チェックポイントに復元するためには、チェックポイントが作成されたときと同じ方法でオブジェクト全体を再作成する必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l6l3EB-Yhwmz"
      },
      "outputs": [],
      "source": [
        "train_checkpointer.initialize_or_restore()\n",
        "global_step = tf.compat.v1.train.get_global_step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nb8_MSE2XjRp"
      },
      "source": [
        "また、ポリシーを保存して指定する場所にエクスポートします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3xHz09WCWjwA"
      },
      "outputs": [],
      "source": [
        "tf_policy_saver.save(policy_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mz-xScbuh4Vo"
      },
      "source": [
        "ポリシーの作成に使用されたエージェントまたはネットワークについての知識がなくても、ポリシーを読み込めるので、ポリシーのデプロイが非常に簡単になります。\n",
        "\n",
        "保存されたポリシーを読み込み、それがどのように機能するかを確認します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J6T5KLTMh9ZB"
      },
      "outputs": [],
      "source": [
        "saved_policy = tf.compat.v2.saved_model.load(policy_dir)\n",
        "run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MpE0KKfqjc0c"
      },
      "source": [
        "## エクスポートとインポート\n",
        "\n",
        "以下は、Checkpointerとポリシーディレクトリのエクスポート/インポートに役立ち、後でトレーニングを継続して、再度トレーニングすることなくモデルをデプロイできます。\n",
        "\n",
        "「1回のイテレーションのトレーニング」に戻り、後で違いを理解できるように、さらに数回トレーニングします。 結果が少し改善し始めたら、以下に進みます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "fd5Cj7DVjfH4"
      },
      "outputs": [],
      "source": [
        "#@title Create zip file and upload zip file (double-click to see the code)\n",
        "def create_zip_file(dirname, base_filename):\n",
        "  return shutil.make_archive(base_filename, 'zip', dirname)\n",
        "\n",
        "def upload_and_unzip_file_to(dirname):\n",
        "  if files is None:\n",
        "    return\n",
        "  uploaded = files.upload()\n",
        "  for fn in uploaded.keys():\n",
        "    print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
        "        name=fn, length=len(uploaded[fn])))\n",
        "    shutil.rmtree(dirname)\n",
        "    zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')\n",
        "    zip_files.extractall(dirname)\n",
        "    zip_files.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hgyy29doHCmL"
      },
      "source": [
        "チェックポイントディレクトリからzipファイルを作成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nhR8NeWzF4fe"
      },
      "outputs": [],
      "source": [
        "train_checkpointer.save(global_step)\n",
        "checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VGEpntTocd2u"
      },
      "source": [
        "zipファイルをダウンロードします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "upFxb5k8b4MC"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "if files is not None:\n",
        "  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VRaZMrn5jLmE"
      },
      "source": [
        "10〜15回ほどトレーニングした後、チェックポイントのzipファイルをダウンロードし、[ランタイム]> [再起動してすべて実行]に移動してトレーニングをリセットし、このセルに戻ります。ダウンロードしたzipファイルをアップロードして、トレーニングを続けます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kg-bKgMsF-H_"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "upload_and_unzip_file_to(checkpoint_dir)\n",
        "train_checkpointer.initialize_or_restore()\n",
        "global_step = tf.compat.v1.train.get_global_step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uXrNax5Zk3vF"
      },
      "source": [
        "チェックポイントディレクトリをアップロードしたら、「1回のイテレーションのトレーニング」に戻ってトレーニングを続けるか、「ビデオ生成」に戻って読み込まれたポリシーのパフォーマンスを確認します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OAkvVZ-NeN2j"
      },
      "source": [
        "または、ポリシー（モデル）を保存して復元することもできます。Checkpointerとは異なり、トレーニングを続けることはできませんが、モデルをデプロイすることはできます。ダウンロードしたファイルはCheckpointerのファイルよりも大幅に小さいことに注意してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s7qMn6D8eiIA"
      },
      "outputs": [],
      "source": [
        "tf_policy_saver.save(policy_dir)\n",
        "policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rrGvCEXwerJj"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "if files is not None:\n",
        "  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DyC_O_gsgSi5"
      },
      "source": [
        "ダウンロードしたポリシーディレクトリ（exported_policy.zip）をアップロードし、保存したポリシーの動作を確認します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bgWLimRlXy5z"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "upload_and_unzip_file_to(policy_dir)\n",
        "saved_policy = tf.compat.v2.saved_model.load(policy_dir)\n",
        "run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HSehXThTm4af"
      },
      "source": [
        "## SavedModelPyTFEagerPolicy\n",
        "\n",
        "TFポリシーを使用しない場合は、`py_tf_eager_policy.SavedModelPyTFEagerPolicy`を使用して、Python envでsaved_modelを直接使用することもできます。\n",
        "\n",
        "これは、eagerモードが有効になっている場合にのみ機能することに注意してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iUC5XuLf1jF7"
      },
      "outputs": [],
      "source": [
        "eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(\n",
        "    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())\n",
        "\n",
        "# Note that we're passing eval_py_env not eval_env.\n",
        "run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "10_checkpointer_policysaver_tutorial.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
